{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "distilbert_checkpoint.ipynb",
      "provenance": [],
      "machine_shape": "hm",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/jalammar/jalammar.github.io/blob/master/notebooks/distilbert_checkpoint.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fvFvBLJV0Dkv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import torch\n",
        "from sklearn.linear_model import LogisticRegression\n",
        "from sklearn.model_selection import cross_val_score\n",
        "from sklearn.model_selection import train_test_split"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "To9ENLU90WGl",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 632
        },
        "outputId": "6773d3fc-350a-4bdf-bc85-d2a2ed5b95df"
      },
      "source": [
        "!pip install transformers==2.1.1"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Collecting transformers\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/fd/f9/51824e40f0a23a49eab4fcaa45c1c797cbf9761adedd0b558dab7c958b34/transformers-2.1.1-py3-none-any.whl (311kB)\n",
            "\r\u001b[K     |█                               | 10kB 14.5MB/s eta 0:00:01\r\u001b[K     |██                              | 20kB 2.2MB/s eta 0:00:01\r\u001b[K     |███▏                            | 30kB 3.2MB/s eta 0:00:01\r\u001b[K     |████▏                           | 40kB 2.1MB/s eta 0:00:01\r\u001b[K     |█████▎                          | 51kB 2.6MB/s eta 0:00:01\r\u001b[K     |██████▎                         | 61kB 3.0MB/s eta 0:00:01\r\u001b[K     |███████▍                        | 71kB 3.5MB/s eta 0:00:01\r\u001b[K     |████████▍                       | 81kB 3.9MB/s eta 0:00:01\r\u001b[K     |█████████▌                      | 92kB 4.3MB/s eta 0:00:01\r\u001b[K     |██████████▌                     | 102kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████▋                    | 112kB 3.4MB/s eta 0:00:01\r\u001b[K     |████████████▋                   | 122kB 3.4MB/s eta 0:00:01\r\u001b[K     |█████████████▊                  | 133kB 3.4MB/s eta 0:00:01\r\u001b[K     |██████████████▊                 | 143kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████████▊                | 153kB 3.4MB/s eta 0:00:01\r\u001b[K     |████████████████▉               | 163kB 3.4MB/s eta 0:00:01\r\u001b[K     |█████████████████▉              | 174kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████████████             | 184kB 3.4MB/s eta 0:00:01\r\u001b[K     |████████████████████            | 194kB 3.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████           | 204kB 3.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████          | 215kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████▏        | 225kB 3.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████▏       | 235kB 3.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▎      | 245kB 3.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▎     | 256kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▍    | 266kB 3.4MB/s eta 0:00:01\r\u001b[K     
|████████████████████████████▍   | 276kB 3.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▍  | 286kB 3.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▌ | 296kB 3.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▌| 307kB 3.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 317kB 3.4MB/s \n",
            "\u001b[?25hRequirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from transformers) (4.28.1)\n",
            "Collecting sentencepiece\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/14/3d/efb655a670b98f62ec32d66954e1109f403db4d937c50d779a75b9763a29/sentencepiece-0.1.83-cp36-cp36m-manylinux1_x86_64.whl (1.0MB)\n",
            "\r\u001b[K     |▎                               | 10kB 14.5MB/s eta 0:00:01\r\u001b[K     |▋                               | 20kB 20.6MB/s eta 0:00:01\r\u001b[K     |█                               | 30kB 27.0MB/s eta 0:00:01\r\u001b[K     |█▎                              | 40kB 31.5MB/s eta 0:00:01\r\u001b[K     |█▋                              | 51kB 34.9MB/s eta 0:00:01\r\u001b[K     |██                              | 61kB 39.0MB/s eta 0:00:01\r\u001b[K     |██▏                             | 71kB 41.5MB/s eta 0:00:01\r\u001b[K     |██▌                             | 81kB 42.8MB/s eta 0:00:01\r\u001b[K     |██▉                             | 92kB 45.1MB/s eta 0:00:01\r\u001b[K     |███▏                            | 102kB 47.4MB/s eta 0:00:01\r\u001b[K     |███▌                            | 112kB 47.4MB/s eta 0:00:01\r\u001b[K     |███▉                            | 122kB 47.4MB/s eta 0:00:01\r\u001b[K     |████                            | 133kB 47.4MB/s eta 0:00:01\r\u001b[K     |████▍                           | 143kB 47.4MB/s eta 0:00:01\r\u001b[K     |████▊                           | 153kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████                           | 163kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████▍                          | 174kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████▊                          | 184kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████                          | 194kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████▎                         | 204kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████▋                         | 215kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████                         | 225kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████▎                        | 235kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████▋                        | 245kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████▉                        | 256kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████▏                       | 266kB 47.4MB/s eta 0:00:01\r\u001b[K     
|████████▌                       | 276kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████▉                       | 286kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████▏                      | 296kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████▌                      | 307kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████▊                      | 317kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████                      | 327kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████▍                     | 337kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████▊                     | 348kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████                     | 358kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████▍                    | 368kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████▋                    | 378kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████                    | 389kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████▎                   | 399kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████▋                   | 409kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████                   | 419kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████▎                  | 430kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████▌                  | 440kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████▉                  | 450kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████▏                 | 460kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████▌                 | 471kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████▉                 | 481kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████▏                | 491kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████▍                | 501kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████▊                | 512kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████                | 522kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████▍               | 532kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████▊      
         | 542kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████               | 552kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████▎              | 563kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████▋              | 573kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████              | 583kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████▎             | 593kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████▋             | 604kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████             | 614kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████▏            | 624kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████▌            | 634kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████▉            | 645kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████████▏           | 655kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████████▌           | 665kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████████▉           | 675kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████▏          | 686kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████▍          | 696kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████▊          | 706kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████          | 716kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████▍         | 727kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████▊         | 737kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████         | 747kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████▎        | 757kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████▋        | 768kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████        | 778kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████▎       | 788kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████▋       | 798kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████████       | 808kB 
47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▏      | 819kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▌      | 829kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▉      | 839kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▏     | 849kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▌     | 860kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▉     | 870kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████████     | 880kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▍    | 890kB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▊    | 901kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████████    | 911kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▍   | 921kB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▊   | 931kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████   | 942kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▎  | 952kB 47.4MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▋  | 962kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████  | 972kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▎ | 983kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▋ | 993kB 47.4MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▉ | 1.0MB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▏| 1.0MB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▌| 1.0MB 47.4MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▉| 1.0MB 47.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 1.0MB 47.4MB/s \n",
            "\u001b[?25hCollecting sacremoses\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/1f/8e/ed5364a06a9ba720fddd9820155cc57300d28f5f43a6fd7b7e817177e642/sacremoses-0.0.35.tar.gz (859kB)\n",
            "\r\u001b[K     |▍                               | 10kB 15.3MB/s eta 0:00:01\r\u001b[K     |▊                               | 20kB 21.7MB/s eta 0:00:01\r\u001b[K     |█▏                              | 30kB 28.1MB/s eta 0:00:01\r\u001b[K     |█▌                              | 40kB 33.4MB/s eta 0:00:01\r\u001b[K     |██                              | 51kB 35.7MB/s eta 0:00:01\r\u001b[K     |██▎                             | 61kB 39.4MB/s eta 0:00:01\r\u001b[K     |██▊                             | 71kB 41.9MB/s eta 0:00:01\r\u001b[K     |███                             | 81kB 43.8MB/s eta 0:00:01\r\u001b[K     |███▍                            | 92kB 46.5MB/s eta 0:00:01\r\u001b[K     |███▉                            | 102kB 48.8MB/s eta 0:00:01\r\u001b[K     |████▏                           | 112kB 48.8MB/s eta 0:00:01\r\u001b[K     |████▋                           | 122kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████                           | 133kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████▍                          | 143kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████▊                          | 153kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████                          | 163kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████▌                         | 174kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████▉                         | 184kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████▎                        | 194kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████▋                        | 204kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████                        | 215kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████▍                       | 225kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████▊                       | 235kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████▏                      | 245kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████▌                      | 256kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████                      | 266kB 48.8MB/s eta 0:00:01\r\u001b[K     
|██████████▎                     | 276kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████▊                     | 286kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████                     | 296kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████▍                    | 307kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████▉                    | 317kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████▏                   | 327kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████▋                   | 337kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████                   | 348kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████▍                  | 358kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████▊                  | 368kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████                  | 378kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████▌                 | 389kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████▉                 | 399kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████▎                | 409kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████▋                | 419kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████████                | 430kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████████▍               | 440kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████████▊               | 450kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████████▏              | 460kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████████▌              | 471kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████              | 481kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████▎             | 491kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████▊             | 501kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████████             | 512kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████████▍            | 522kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████████▉            | 532kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████████████▏  
         | 542kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████████████▋           | 552kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████           | 563kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████▍          | 573kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████▊          | 583kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████          | 593kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████▌         | 604kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████▉         | 614kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████▎        | 624kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████▋        | 634kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████        | 645kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████▍       | 655kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████▊       | 665kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▏      | 675kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████▌      | 686kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████      | 696kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▎     | 706kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████▊     | 716kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████     | 727kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▍    | 737kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████▉    | 747kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▏   | 757kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████████▋   | 768kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████   | 778kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▍  | 788kB 48.8MB/s eta 0:00:01\r\u001b[K     |█████████████████████████████▊  | 798kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████  | 808kB 
48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▌ | 819kB 48.8MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▉ | 829kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▎| 839kB 48.8MB/s eta 0:00:01\r\u001b[K     |███████████████████████████████▋| 849kB 48.8MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 860kB 48.8MB/s \n",
            "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.17.3)\n",
            "Collecting regex\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/e3/8e/cbf2295643d7265e7883326fb4654e643bfc93b3a8a8274d8010a39d8804/regex-2019.11.1-cp36-cp36m-manylinux1_x86_64.whl (643kB)\n",
            "\u001b[K     |████████████████████████████████| 645kB 43.5MB/s \n",
            "\u001b[?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from transformers) (1.10.4)\n",
            "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.21.0)\n",
            "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.12.0)\n",
            "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.0)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.14.0)\n",
            "Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.2.1)\n",
            "Requirement already satisfied: botocore<1.14.0,>=1.13.4 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (1.13.4)\n",
            "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->transformers) (0.9.4)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2019.9.11)\n",
            "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n",
            "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.8)\n",
            "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n",
            "Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.4->boto3->transformers) (0.15.2)\n",
            "Requirement already satisfied: python-dateutil<3.0.0,>=2.1; python_version >= \"2.7\" in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.4->boto3->transformers) (2.6.1)\n",
            "Building wheels for collected packages: sacremoses\n",
            "  Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
            "  Created wheel for sacremoses: filename=sacremoses-0.0.35-cp36-none-any.whl size=883999 sha256=60135dbf1248bf3a94375ddbeec791a41ffa74f0346a7f96d2c25d41f0608ce1\n",
            "  Stored in directory: /root/.cache/pip/wheels/63/2a/db/63e2909042c634ef551d0d9ac825b2b0b32dede4a6d87ddc94\n",
            "Successfully built sacremoses\n",
            "Installing collected packages: sentencepiece, sacremoses, regex, transformers\n",
            "Successfully installed regex-2019.11.1 sacremoses-0.0.35 sentencepiece-0.1.83 transformers-2.1.1\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "POT8e8VX0Ry1",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 63
        },
        "outputId": "9ebfb397-6657-4ad8-ee0e-3496be561d26"
      },
      "source": [
        "import transformers as ppb"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "<p style=\"color: red;\">\n",
              "The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.<br>\n",
              "We recommend you <a href=\"https://www.tensorflow.org/guide/migrate\" target=\"_blank\">upgrade</a> now \n",
              "or ensure your notebook will continue to use TensorFlow 1.x via the <code>%tensorflow_version 1.x</code> magic:\n",
              "<a href=\"https://colab.research.google.com/notebooks/tensorflow_version.ipynb\" target=\"_blank\">more info</a>.</p>\n"
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cyoj29J24hPX",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "df = pd.read_csv('https://github.com/clairett/pytorch-sentiment-classification/raw/master/data/SST2/train.tsv', delimiter='\\t', header=None)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VDSUXdbRatJc",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 204
        },
        "outputId": "54cfb455-fc99-494d-d692-6bcde6e1aa36"
      },
      "source": [
        "df.head()"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>0</th>\n",
              "      <th>1</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>a stirring , funny and finally transporting re...</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>apparently reassembled from the cutting room f...</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>they presume their audience wo n't sit still f...</td>\n",
              "      <td>0</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>this is a visually stunning rumination on love...</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>jonathan parker 's bartleby should have been t...</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "                                                   0  1\n",
              "0  a stirring , funny and finally transporting re...  1\n",
              "1  apparently reassembled from the cutting room f...  0\n",
              "2  they presume their audience wo n't sit still f...  0\n",
              "3  this is a visually stunning rumination on love...  1\n",
              "4  jonathan parker 's bartleby should have been t...  1"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "wrfxiG7Qa2ce",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 139
        },
        "outputId": "5b9b6ad7-0a91-4855-f13e-b1d7ee05b028"
      },
      "source": [
        "df[:5][0].values"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "array(['a stirring , funny and finally transporting re imagining of beauty and the beast and 1930s horror films',\n",
              "       'apparently reassembled from the cutting room floor of any given daytime soap',\n",
              "       \"they presume their audience wo n't sit still for a sociology lesson , however entertainingly presented , so they trot out the conventional science fiction elements of bug eyed monsters and futuristic women in skimpy clothes\",\n",
              "       'this is a visually stunning rumination on love , memory , history and the war between art and commerce',\n",
              "       \"jonathan parker 's bartleby should have been the be all end all of the modern office anomie films\"],\n",
              "      dtype=object)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 6
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gTM3hOHW4hUY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "batch_1 = df[:2000]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jGvcfcCP5xpZ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        },
        "outputId": "dc9effb8-81a2-4c45-96a5-74c03af41304"
      },
      "source": [
        "batch_1[1].value_counts()"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "1    1041\n",
              "0     959\n",
              "Name: 1, dtype: int64"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 8
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "q1InADgf5xm2",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        },
        "outputId": "b34db792-a727-4e45-b994-14f085a00f98"
      },
      "source": [
        "model_class, tokenizer_class, pretrained_weights = (ppb.DistilBertModel, ppb.DistilBertTokenizer, 'distilbert-base-uncased')\n",
        "\n",
        "# Load pretrained model/tokenizer\n",
        "tokenizer = tokenizer_class.from_pretrained(pretrained_weights)\n",
        "model = model_class.from_pretrained(pretrained_weights)"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 231508/231508 [00:00<00:00, 2070438.45B/s]\n",
            "100%|██████████| 492/492 [00:00<00:00, 96380.25B/s]\n",
            "100%|██████████| 267967963/267967963 [00:05<00:00, 51387886.41B/s]\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CXAVK78V7x4r",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 119
        },
        "outputId": "3118375f-b669-4f40-ef22-8acaaf96dccd"
      },
      "source": [
        "df[0][:5]"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0    a stirring , funny and finally transporting re...\n",
              "1    apparently reassembled from the cutting room f...\n",
              "2    they presume their audience wo n't sit still f...\n",
              "3    this is a visually stunning rumination on love...\n",
              "4    jonathan parker 's bartleby should have been t...\n",
              "Name: 0, dtype: object"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 10
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RIS4tx297oaL",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "17c00929-56f4-4b04-c6a2-257fd8338cd4"
      },
      "source": [
        "tokenizer.encode(\"a visually stunning rumination on love\", add_special_tokens=True)"
      ],
      "execution_count": 11,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[101, 1037, 17453, 14726, 19379, 12758, 2006, 2293, 102]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 11
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "W9I55qzK8R1X",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "956c9337-2062-4cf6-85c4-26af8444538e"
      },
      "source": [
        "tokenizer.decode(102)  # id 102 is BERT's [SEP] special token (see output below)"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "'[ S E P ]'"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 12
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8UPTTvat8R77",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 477
        },
        "outputId": "f320cbbe-f501-4a39-cd2c-d0a0d1ef662a"
      },
      "source": [
        "# tokenized_5 holds only the first 5 rows (indices 0-4), so index 1999\n",
        "# raises the KeyError shown below -- use `tokenized` for the full batch\n",
        "tokenized_5[0], tokenized_5[1], tokenized_5[1999]"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "error",
          "ename": "KeyError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyError\u001b[0m                                  Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-22-56cd2d47136f>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mtokenized_5\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenized_5\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtokenized_5\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1999\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pandas/core/series.py\u001b[0m in \u001b[0;36m__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m   1069\u001b[0m         \u001b[0mkey\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcom\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mapply_if_callable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1070\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1071\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mindex\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1072\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1073\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mis_scalar\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/pandas/core/indexes/base.py\u001b[0m in \u001b[0;36mget_value\u001b[0;34m(self, series, key)\u001b[0m\n\u001b[1;32m   4728\u001b[0m         \u001b[0mk\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_convert_scalar_indexer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkind\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"getitem\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4729\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 4730\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_engine\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtz\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseries\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"tz\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   4731\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mKeyError\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   4732\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mholds_integer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_boolean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32mpandas/_libs/index.pyx\u001b[0m in \u001b[0;36mpandas._libs.index.IndexEngine.get_value\u001b[0;34m()\u001b[0m\n",
            "\u001b[0;32mpandas/_libs/index.pyx\u001b[0m in \u001b[0;36mpandas._libs.index.IndexEngine.get_value\u001b[0;34m()\u001b[0m\n",
            "\u001b[0;32mpandas/_libs/index.pyx\u001b[0m in \u001b[0;36mpandas._libs.index.IndexEngine.get_loc\u001b[0;34m()\u001b[0m\n",
            "\u001b[0;32mpandas/_libs/hashtable_class_helper.pxi\u001b[0m in \u001b[0;36mpandas._libs.hashtable.Int64HashTable.get_item\u001b[0;34m()\u001b[0m\n",
            "\u001b[0;32mpandas/_libs/hashtable_class_helper.pxi\u001b[0m in \u001b[0;36mpandas._libs.hashtable.Int64HashTable.get_item\u001b[0;34m()\u001b[0m\n",
            "\u001b[0;31mKeyError\u001b[0m: 1999"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zlIMCS7Z8GlR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "sentence = tokenizer.encode(\"a visually stunning rumination on love\", add_special_tokens=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HJwh2XqkIFTM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "tokenized_5 = df[0][:5].apply((lambda x: tokenizer.encode(x, add_special_tokens=True)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Dg82ndBA5xlN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "tokenized = batch_1[0].apply((lambda x: tokenizer.encode(x, add_special_tokens=True)))"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cb53a--qLHY_",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 544
        },
        "outputId": "b34621cc-53bd-488c-c03d-e48504096d92"
      },
      "source": [
        "tokenized[1999]"
      ],
      "execution_count": 21,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[101,\n",
              " 1996,\n",
              " 3185,\n",
              " 2003,\n",
              " 25757,\n",
              " 2011,\n",
              " 1037,\n",
              " 24466,\n",
              " 16134,\n",
              " 2008,\n",
              " 1005,\n",
              " 1055,\n",
              " 2074,\n",
              " 6388,\n",
              " 2438,\n",
              " 2000,\n",
              " 7344,\n",
              " 3686,\n",
              " 1996,\n",
              " 7731,\n",
              " 4378,\n",
              " 2096,\n",
              " 13060,\n",
              " 18856,\n",
              " 17322,\n",
              " 2094,\n",
              " 2000,\n",
              " 15015,\n",
              " 10271,\n",
              " 4641,\n",
              " 102]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 21
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "URn-DWJt5xhP",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "450c75f0-adc4-412c-bf14-70c79b970606"
      },
      "source": [
        "# Find the longest tokenized sequence so every row can be padded to equal length\n",
        "max_len = 0\n",
        "for i in tokenized.values:\n",
        "    if len(i) > max_len:\n",
        "        max_len = len(i)\n",
        "max_len"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "59"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jdi7uXo95xeq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Right-pad each token-id list with 0 up to max_len so they form a rectangular array\n",
        "padded = [i + [0]*(max_len-len(i)) for i in tokenized.values]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "5lEpOI5r5xbT",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "2b768828-91da-44a1-f7c4-c23995ee2374"
      },
      "source": [
        "np.array(padded).shape"
      ],
      "execution_count": 16,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(2000, 59)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 16
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "39UVjAV56PJz",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "input_ids = torch.tensor(np.array(padded))  # (batch, max_len) tensor of token ids\n",
        "# NOTE(review): no attention_mask is passed, so the model also attends to the\n",
        "# 0-padding; consider attention_mask = torch.tensor(np.where(np.array(padded) != 0, 1, 0))\n",
        "# input_ids = torch.tensor([sentence])  # single-sentence alternative\n",
        "\n",
        "with torch.no_grad():\n",
        "    last_hidden_states = model(input_ids)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C-PTTvAtC_hj",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "6fd663b9-0eea-4705-f0c6-ae14b8929fa8"
      },
      "source": [
        "last_hidden_states[0].shape"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "torch.Size([2000, 59, 768])"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 18
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "C9t60At16PVs",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Keep only the hidden state of the first ([CLS]) token of every sentence\n",
        "# as its embedding: (2000, 59, 768) -> (2000, 768)\n",
        "features = last_hidden_states[0][:,0,:].numpy()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JD3fX2yh6PTx",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "labels = batch_1[1]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ddAqbkoU6PP9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from sklearn.model_selection import train_test_split\n",
        "train_features, test_features, train_labels, test_labels = train_test_split(features, labels)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8S0oNuHLu4V_",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "303b949d-e704-4ddc-e633-c01694e6166a"
      },
      "source": [
        "from sklearn.model_selection import GridSearchCV\n",
        "parameters = {'C': np.linspace(0.001, 100, 20)}\n",
        "grid_search = GridSearchCV(LogisticRegression(), parameters)\n",
        "grid_search.fit(train_features, train_labels)"
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_split.py:1978: FutureWarning: The default value of cv will change from 3 to 5 in version 0.22. Specify it explicitly to silence this warning.\n",
            "  warnings.warn(CV_WARNING, FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "GridSearchCV(cv='warn', error_score='raise-deprecating',\n",
              "             estimator=LogisticRegression(C=1.0, class_weight=None, dual=False,\n",
              "                                          fit_intercept=True,\n",
              "                                          intercept_scaling=1, l1_ratio=None,\n",
              "                                          max_iter=100, multi_class='warn',\n",
              "                                          n_jobs=None, penalty='l2',\n",
              "                                          random_state=None, solver='warn',\n",
              "                                          tol=0.0001, verbose=0,\n",
              "                                          warm_start=False),\n",
              "             iid='warn', n_jobs=None,\n",
              "             param_grid={'C': array([1.00...2105e+01, 1.57903158e+01,\n",
              "       2.10534211e+01, 2.63165263e+01, 3.15796316e+01, 3.68427368e+01,\n",
              "       4.21058421e+01, 4.73689474e+01, 5.26320526e+01, 5.78951579e+01,\n",
              "       6.31582632e+01, 6.84213684e+01, 7.36844737e+01, 7.89475789e+01,\n",
              "       8.42106842e+01, 8.94737895e+01, 9.47368947e+01, 1.00000000e+02])},\n",
              "             pre_dispatch='2*n_jobs', refit=True, return_train_score=False,\n",
              "             scoring=None, verbose=0)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 22
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "OapFyXIovbFq",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "cc2e3ad1-9277-4c54-ef07-d23c703f9504"
      },
      "source": [
        "grid_search.best_params_"
      ],
      "execution_count": 24,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "{'C': 5.264105263157894}"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 24
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "pHyB8FwlvbYZ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "46faab80-7577-432f-e272-d35cddcf2a64"
      },
      "source": [
        "grid_search.best_score_"
      ],
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.7986666666666666"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 25
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4TEWilVrvblN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "lr_clf = LogisticRegression(C=5.264105263157894)  # best C found by the grid search above"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kVa-s4Jj6POm",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 190
        },
        "outputId": "41628e49-83e8-4164-e189-f78b21ea64d4"
      },
      "source": [
        "# Cross-validate the logistic-regression classifier using the C found by the\n",
        "# grid search (~5.264). The sklearn defaults for cv and solver are used, which\n",
        "# is what produces the FutureWarnings in the recorded output below.\n",
        "# NOTE(review): train_features / train_labels come from cells above this excerpt.\n",
        "\n",
        "lr_clf = LogisticRegression(C= 5.264)\n",
        "# lr_clf.fit(train_features, train_labels)\n",
        "\n",
        "scores = cross_val_score(lr_clf, train_features, train_labels)\n",
        "print(\"Logistic Regression classifier score: %0.3f (+/- %0.2f)\" % (scores.mean(), scores.std() * 2))"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_split.py:1978: FutureWarning: The default value of cv will change from 3 to 5 in version 0.22. Specify it explicitly to silence this warning.\n",
            "  warnings.warn(CV_WARNING, FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Logistic Regression classifier score: 0.799 (+/- 0.02)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gG-EVWx4CzBc",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 88
        },
        "outputId": "1b2e48c6-24a1-480b-df0f-f4273c27a626"
      },
      "source": [
        "# Refit the tuned classifier on the full training split and report accuracy\n",
        "# on the held-out test set (0.81 in the recorded run).\n",
        "lr_clf = LogisticRegression(C=5.264)\n",
        "lr_clf.fit(train_features, train_labels)\n",
        "lr_clf.score(test_features, test_labels)"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/sklearn/linear_model/logistic.py:432: FutureWarning: Default solver will be changed to 'lbfgs' in 0.22. Specify a solver to silence this warning.\n",
            "  FutureWarning)\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.81"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SB3y3iu26PHM",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 190
        },
        "outputId": "b8108bc4-37f4-4844-a478-0cca8c1bcc7a"
      },
      "source": [
        "# Baseline comparison: random forest with default hyperparameters\n",
        "# (n_estimators=10 in this sklearn version, per the warning in the output).\n",
        "from sklearn.ensemble import RandomForestClassifier\n",
        "random_forest = RandomForestClassifier()\n",
        "scores = cross_val_score(random_forest, train_features, train_labels)\n",
        "print(\"Random forest classifier score: %0.3f (+/- %0.2f)\" % (scores.mean(), scores.std() * 2))\n"
      ],
      "execution_count": 18,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_split.py:1978: FutureWarning: The default value of cv will change from 3 to 5 in version 0.22. Specify it explicitly to silence this warning.\n",
            "  warnings.warn(CV_WARNING, FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/ensemble/forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.\n",
            "  \"10 in version 0.20 to 100 in 0.22.\", FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/ensemble/forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.\n",
            "  \"10 in version 0.20 to 100 in 0.22.\", FutureWarning)\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Random forest classifier score: 0.683 (+/- 0.04)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/sklearn/ensemble/forest.py:245: FutureWarning: The default value of n_estimators will change from 10 in version 0.20 to 100 in 0.22.\n",
            "  \"10 in version 0.20 to 100 in 0.22.\", FutureWarning)\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BjkK2NBZ4hJE",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 190
        },
        "outputId": "49f38852-7fe4-48ec-b25f-6e422f6aedf9"
      },
      "source": [
        "# Baseline comparison: support-vector classifier with default hyperparameters.\n",
        "from sklearn.svm import SVC\n",
        "# SVC is a classifier, not a regressor — the old name `svc_regressor` was misleading.\n",
        "svc_classifier = SVC()\n",
        "\n",
        "scores = cross_val_score(svc_classifier, train_features, train_labels)\n",
        "print(\"SVM classifier average score: %0.3f (+/- %0.2f)\" % (scores.mean(), scores.std() * 2))\n"
      ],
      "execution_count": 19,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_split.py:1978: FutureWarning: The default value of cv will change from 3 to 5 in version 0.22. Specify it explicitly to silence this warning.\n",
            "  warnings.warn(CV_WARNING, FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/svm/base.py:193: FutureWarning: The default value of gamma will change from 'auto' to 'scale' in version 0.22 to account better for unscaled features. Set gamma explicitly to 'auto' or 'scale' to avoid this warning.\n",
            "  \"avoid this warning.\", FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/svm/base.py:193: FutureWarning: The default value of gamma will change from 'auto' to 'scale' in version 0.22 to account better for unscaled features. Set gamma explicitly to 'auto' or 'scale' to avoid this warning.\n",
            "  \"avoid this warning.\", FutureWarning)\n",
            "/usr/local/lib/python3.6/dist-packages/sklearn/svm/base.py:193: FutureWarning: The default value of gamma will change from 'auto' to 'scale' in version 0.22 to account better for unscaled features. Set gamma explicitly to 'auto' or 'scale' to avoid this warning.\n",
            "  \"avoid this warning.\", FutureWarning)\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "SVM classifier average score: 0.525 (+/- 0.00)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hTv5YpnN7iyx",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "c74a15e7-2e7b-4587-c07c-e776aff63c1f"
      },
      "source": [
        "# Baseline comparison: gradient-boosted trees, 5-fold cross-validation.\n",
        "from sklearn.ensemble import GradientBoostingClassifier\n",
        "# est = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1,max_depth=1, random_state=0, loss='ls')\n",
        "gbc = GradientBoostingClassifier()\n",
        "\n",
        "scores = cross_val_score(gbc, train_features, train_labels, cv=5)\n",
        "print(\"Gradient Boosting classifier accuracy: %0.3f (+/- %0.2f)\" % (scores.mean(), scores.std() * 2))\n",
        "\n"
      ],
      "execution_count": 20,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Gradient Boosting classifier accuracy: 0.765 (+/- 0.03)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tzQYSP-q7i7T",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "aac87a1c-0544-4618-d3b1-f290a6db3e8f"
      },
      "source": [
        "# Baseline comparison: two-hidden-layer MLP (1000 and 300 units), scored on the\n",
        "# held-out test set. No random seed is set, so results vary between runs.\n",
        "from sklearn.neural_network import MLPClassifier\n",
        "neural_network = MLPClassifier(hidden_layer_sizes=(1000, 300), learning_rate='adaptive', verbose=False)\n",
        "neural_network.fit(train_features, train_labels)\n",
        "neural_network.score(test_features, test_labels)\n",
        "# scores = cross_val_score(neural_network, train_features, train_labels)\n",
        "# print(\"Neural network regressor score: %0.3f (+/- %0.2f)\" % (scores.mean(), scores.std() * 2))\n"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.792"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 26
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8xb3_A7cEby9",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "c31d235e-f5f3-4284-f0a2-10d583e0aaae"
      },
      "source": [
        "# Baseline comparison: single-hidden-layer MLP with 2000 units.\n",
        "from sklearn.neural_network import MLPClassifier\n",
        "# (2000) is just the int 2000, not a one-element tuple; sklearn happens to accept\n",
        "# a scalar here, but (2000,) states the single-layer intent explicitly.\n",
        "neural_network = MLPClassifier(hidden_layer_sizes=(2000,), learning_rate='adaptive', verbose=False)\n",
        "neural_network.fit(train_features, train_labels)\n",
        "neural_network.score(test_features, test_labels)\n"
      ],
      "execution_count": 27,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.784"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 27
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lnwgmqNG7i5l",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 88
        },
        "outputId": "2989dae4-7690-463a-a582-4047fb8d196b"
      },
      "source": [
        "# Sanity-check floor: a dummy classifier scores ~0.50, confirming the other\n",
        "# models are learning something from the features.\n",
        "from sklearn.dummy import DummyClassifier\n",
        "clf = DummyClassifier()\n",
        "\n",
        "scores = cross_val_score(clf, train_features, train_labels)\n",
        "print(\"Dummy classifier score: %0.3f (+/- %0.2f)\" % (scores.mean(), scores.std() * 2))"
      ],
      ],
      "execution_count": 22,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Dummy classifier score: 0.497 (+/- 0.02)\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/sklearn/model_selection/_split.py:1978: FutureWarning: The default value of cv will change from 3 to 5 in version 0.22. Specify it explicitly to silence this warning.\n",
            "  warnings.warn(CV_WARNING, FutureWarning)\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "YYwrotgc7i31",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JDuQ6im17iv_",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GW8VoPCl0Z3q",
        "colab_type": "code",
        "colab": {
          "resources": {
            "http://localhost:8080/nbextensions/google.colab/files.js": {
              "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7Ci8vIE1heCBhbW91bnQgb2YgdGltZSB0byBibG9jayB3YWl0aW5nIGZvciB0aGUgdXNlci4KY29uc3QgRklMRV9DSEFOR0VfVElNRU9VVF9NUyA9IDMwICogMTAwMDsKCmZ1bmN0aW9uIF91cGxvYWRGaWxlcyhpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IHN0ZXBzID0gdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKTsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIC8vIENhY2hlIHN0ZXBzIG9uIHRoZSBvdXRwdXRFbGVtZW50IHRvIG1ha2UgaXQgYXZhaWxhYmxlIGZvciB0aGUgbmV4dCBjYWxsCiAgLy8gdG8gdXBsb2FkRmlsZXNDb250aW51ZSBmcm9tIFB5dGhvbi4KICBvdXRwdXRFbGVtZW50LnN0ZXBzID0gc3RlcHM7CgogIHJldHVybiBfdXBsb2FkRmlsZXNDb250aW51ZShvdXRwdXRJZCk7Cn0KCi8vIFRoaXMgaXMgcm91Z2hseSBhbiBhc
3luYyBnZW5lcmF0b3IgKG5vdCBzdXBwb3J0ZWQgaW4gdGhlIGJyb3dzZXIgeWV0KSwKLy8gd2hlcmUgdGhlcmUgYXJlIG11bHRpcGxlIGFzeW5jaHJvbm91cyBzdGVwcyBhbmQgdGhlIFB5dGhvbiBzaWRlIGlzIGdvaW5nCi8vIHRvIHBvbGwgZm9yIGNvbXBsZXRpb24gb2YgZWFjaCBzdGVwLgovLyBUaGlzIHVzZXMgYSBQcm9taXNlIHRvIGJsb2NrIHRoZSBweXRob24gc2lkZSBvbiBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcCwKLy8gdGhlbiBwYXNzZXMgdGhlIHJlc3VsdCBvZiB0aGUgcHJldmlvdXMgc3RlcCBhcyB0aGUgaW5wdXQgdG8gdGhlIG5leHQgc3RlcC4KZnVuY3Rpb24gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpIHsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIGNvbnN0IHN0ZXBzID0gb3V0cHV0RWxlbWVudC5zdGVwczsKCiAgY29uc3QgbmV4dCA9IHN0ZXBzLm5leHQob3V0cHV0RWxlbWVudC5sYXN0UHJvbWlzZVZhbHVlKTsKICByZXR1cm4gUHJvbWlzZS5yZXNvbHZlKG5leHQudmFsdWUucHJvbWlzZSkudGhlbigodmFsdWUpID0+IHsKICAgIC8vIENhY2hlIHRoZSBsYXN0IHByb21pc2UgdmFsdWUgdG8gbWFrZSBpdCBhdmFpbGFibGUgdG8gdGhlIG5leHQKICAgIC8vIHN0ZXAgb2YgdGhlIGdlbmVyYXRvci4KICAgIG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSA9IHZhbHVlOwogICAgcmV0dXJuIG5leHQudmFsdWUucmVzcG9uc2U7CiAgfSk7Cn0KCi8qKgogKiBHZW5lcmF0b3IgZnVuY3Rpb24gd2hpY2ggaXMgY2FsbGVkIGJldHdlZW4gZWFjaCBhc3luYyBzdGVwIG9mIHRoZSB1cGxvYWQKICogcHJvY2Vzcy4KICogQHBhcmFtIHtzdHJpbmd9IGlucHV0SWQgRWxlbWVudCBJRCBvZiB0aGUgaW5wdXQgZmlsZSBwaWNrZXIgZWxlbWVudC4KICogQHBhcmFtIHtzdHJpbmd9IG91dHB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIG91dHB1dCBkaXNwbGF5LgogKiBAcmV0dXJuIHshSXRlcmFibGU8IU9iamVjdD59IEl0ZXJhYmxlIG9mIG5leHQgc3RlcHMuCiAqLwpmdW5jdGlvbiogdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKSB7CiAgY29uc3QgaW5wdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQoaW5wdXRJZCk7CiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gZmFsc2U7CgogIGNvbnN0IG91dHB1dEVsZW1lbnQgPSBkb2N1bWVudC5nZXRFbGVtZW50QnlJZChvdXRwdXRJZCk7CiAgb3V0cHV0RWxlbWVudC5pbm5lckhUTUwgPSAnJzsKCiAgY29uc3QgcGlja2VkUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBpbnB1dEVsZW1lbnQuYWRkRXZlbnRMaXN0ZW5lcignY2hhbmdlJywgKGUpID0+IHsKICAgICAgcmVzb2x2ZShlLnRhcmdldC5maWxlcyk7CiAgICB9KTsKICB9KTsKCiAgY29uc3QgY2FuY2VsID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnYnV0dG9uJyk7CiAgaW5wdXRFbGVtZW50LnBhc
mVudEVsZW1lbnQuYXBwZW5kQ2hpbGQoY2FuY2VsKTsKICBjYW5jZWwudGV4dENvbnRlbnQgPSAnQ2FuY2VsIHVwbG9hZCc7CiAgY29uc3QgY2FuY2VsUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBjYW5jZWwub25jbGljayA9ICgpID0+IHsKICAgICAgcmVzb2x2ZShudWxsKTsKICAgIH07CiAgfSk7CgogIC8vIENhbmNlbCB1cGxvYWQgaWYgdXNlciBoYXNuJ3QgcGlja2VkIGFueXRoaW5nIGluIHRpbWVvdXQuCiAgY29uc3QgdGltZW91dFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgc2V0VGltZW91dCgoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9LCBGSUxFX0NIQU5HRV9USU1FT1VUX01TKTsKICB9KTsKCiAgLy8gV2FpdCBmb3IgdGhlIHVzZXIgdG8gcGljayB0aGUgZmlsZXMuCiAgY29uc3QgZmlsZXMgPSB5aWVsZCB7CiAgICBwcm9taXNlOiBQcm9taXNlLnJhY2UoW3BpY2tlZFByb21pc2UsIHRpbWVvdXRQcm9taXNlLCBjYW5jZWxQcm9taXNlXSksCiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdzdGFydGluZycsCiAgICB9CiAgfTsKCiAgaWYgKCFmaWxlcykgewogICAgcmV0dXJuIHsKICAgICAgcmVzcG9uc2U6IHsKICAgICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICAgIH0KICAgIH07CiAgfQoKICBjYW5jZWwucmVtb3ZlKCk7CgogIC8vIERpc2FibGUgdGhlIGlucHV0IGVsZW1lbnQgc2luY2UgZnVydGhlciBwaWNrcyBhcmUgbm90IGFsbG93ZWQuCiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gdHJ1ZTsKCiAgZm9yIChjb25zdCBmaWxlIG9mIGZpbGVzKSB7CiAgICBjb25zdCBsaSA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2xpJyk7CiAgICBsaS5hcHBlbmQoc3BhbihmaWxlLm5hbWUsIHtmb250V2VpZ2h0OiAnYm9sZCd9KSk7CiAgICBsaS5hcHBlbmQoc3BhbigKICAgICAgICBgKCR7ZmlsZS50eXBlIHx8ICduL2EnfSkgLSAke2ZpbGUuc2l6ZX0gYnl0ZXMsIGAgKwogICAgICAgIGBsYXN0IG1vZGlmaWVkOiAkewogICAgICAgICAgICBmaWxlLmxhc3RNb2RpZmllZERhdGUgPyBmaWxlLmxhc3RNb2RpZmllZERhdGUudG9Mb2NhbGVEYXRlU3RyaW5nKCkgOgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnbi9hJ30gLSBgKSk7CiAgICBjb25zdCBwZXJjZW50ID0gc3BhbignMCUgZG9uZScpOwogICAgbGkuYXBwZW5kQ2hpbGQocGVyY2VudCk7CgogICAgb3V0cHV0RWxlbWVudC5hcHBlbmRDaGlsZChsaSk7CgogICAgY29uc3QgZmlsZURhdGFQcm9taXNlID0gbmV3IFByb21pc2UoKHJlc29sdmUpID0+IHsKICAgICAgY29uc3QgcmVhZGVyID0gbmV3IEZpbGVSZWFkZXIoKTsKICAgICAgcmVhZGVyLm9ubG9hZCA9IChlKSA9PiB7CiAgICAgICAgcmVzb2x2ZShlLnRhcmdldC5yZXN1bHQpOwogICAgICB9OwogICAgICByZWFkZXIucmVhZEFzQXJyYXlCdWZmZXIoZmlsZSk7CiAgICB9KTsKICAgIC8vIFdhaXQgZm9yIHRoZSBkY
XRhIHRvIGJlIHJlYWR5LgogICAgbGV0IGZpbGVEYXRhID0geWllbGQgewogICAgICBwcm9taXNlOiBmaWxlRGF0YVByb21pc2UsCiAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgYWN0aW9uOiAnY29udGludWUnLAogICAgICB9CiAgICB9OwoKICAgIC8vIFVzZSBhIGNodW5rZWQgc2VuZGluZyB0byBhdm9pZCBtZXNzYWdlIHNpemUgbGltaXRzLiBTZWUgYi82MjExNTY2MC4KICAgIGxldCBwb3NpdGlvbiA9IDA7CiAgICB3aGlsZSAocG9zaXRpb24gPCBmaWxlRGF0YS5ieXRlTGVuZ3RoKSB7CiAgICAgIGNvbnN0IGxlbmd0aCA9IE1hdGgubWluKGZpbGVEYXRhLmJ5dGVMZW5ndGggLSBwb3NpdGlvbiwgTUFYX1BBWUxPQURfU0laRSk7CiAgICAgIGNvbnN0IGNodW5rID0gbmV3IFVpbnQ4QXJyYXkoZmlsZURhdGEsIHBvc2l0aW9uLCBsZW5ndGgpOwogICAgICBwb3NpdGlvbiArPSBsZW5ndGg7CgogICAgICBjb25zdCBiYXNlNjQgPSBidG9hKFN0cmluZy5mcm9tQ2hhckNvZGUuYXBwbHkobnVsbCwgY2h1bmspKTsKICAgICAgeWllbGQgewogICAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgICBhY3Rpb246ICdhcHBlbmQnLAogICAgICAgICAgZmlsZTogZmlsZS5uYW1lLAogICAgICAgICAgZGF0YTogYmFzZTY0LAogICAgICAgIH0sCiAgICAgIH07CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPQogICAgICAgICAgYCR7TWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCl9JSBkb25lYDsKICAgIH0KICB9CgogIC8vIEFsbCBkb25lLgogIHlpZWxkIHsKICAgIHJlc3BvbnNlOiB7CiAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgIH0KICB9Owp9CgpzY29wZS5nb29nbGUgPSBzY29wZS5nb29nbGUgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYiA9IHNjb3BlLmdvb2dsZS5jb2xhYiB8fCB7fTsKc2NvcGUuZ29vZ2xlLmNvbGFiLl9maWxlcyA9IHsKICBfdXBsb2FkRmlsZXMsCiAgX3VwbG9hZEZpbGVzQ29udGludWUsCn07Cn0pKHNlbGYpOwo=",
              "ok": true,
              "headers": [
                [
                  "content-type",
                  "application/javascript"
                ]
              ],
              "status": 200,
              "status_text": ""
            }
          },
          "base_uri": "https://localhost:8080/",
          "height": 71
        },
        "outputId": "a5bdaed5-49f7-4afc-f0d3-d6ee26354034"
      },
      "source": [
        "# Colab-only: interactively upload the precomputed CSV (500_with_vectors.csv)\n",
        "# into the runtime. Must be re-run in each new Colab session.\n",
        "from google.colab import files\n",
        "uploaded = files.upload()"
      ],
      "execution_count": 10,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/html": [
              "\n",
              "     <input type=\"file\" id=\"files-ce8f97ba-62b6-4e52-9128-db9ce2e163a9\" name=\"files[]\" multiple disabled />\n",
              "     <output id=\"result-ce8f97ba-62b6-4e52-9128-db9ce2e163a9\">\n",
              "      Upload widget is only available when the cell has been executed in the\n",
              "      current browser session. Please rerun this cell to enable.\n",
              "      </output>\n",
              "      <script src=\"/nbextensions/google.colab/files.js\"></script> "
            ],
            "text/plain": [
              "<IPython.core.display.HTML object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Saving 500_with_vectors.csv to 500_with_vectors.csv\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VDQH4khC0vCY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Load the uploaded CSV of sentences + precomputed vectors, dropping the\n",
        "# 'Unnamed: 0' column that pandas wrote out as the saved index.\n",
        "df = pd.read_csv('500_with_vectors.csv')\n",
        "df = df.drop(['Unnamed: 0'], axis=1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Cn5JzCJ14AMZ",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 541
        },
        "outputId": "55d01902-ab9f-419b-84d7-c0e6f9fe4f25"
      },
      "source": [
        "# NOTE(review): displays the entire wide frame — df.head() would keep the\n",
        "# notebook output much smaller.\n",
        "df"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>sentence</th>\n",
              "      <th>label</th>\n",
              "      <th>split</th>\n",
              "      <th>0</th>\n",
              "      <th>1</th>\n",
              "      <th>2</th>\n",
              "      <th>3</th>\n",
              "      <th>4</th>\n",
              "      <th>5</th>\n",
              "      <th>6</th>\n",
              "      <th>7</th>\n",
              "      <th>8</th>\n",
              "      <th>9</th>\n",
              "      <th>10</th>\n",
              "      <th>11</th>\n",
              "      <th>12</th>\n",
              "      <th>13</th>\n",
              "      <th>14</th>\n",
              "      <th>15</th>\n",
              "      <th>16</th>\n",
              "      <th>17</th>\n",
              "      <th>18</th>\n",
              "      <th>19</th>\n",
              "      <th>20</th>\n",
              "      <th>21</th>\n",
              "      <th>22</th>\n",
              "      <th>23</th>\n",
              "      <th>24</th>\n",
              "      <th>25</th>\n",
              "      <th>26</th>\n",
              "      <th>27</th>\n",
              "      <th>28</th>\n",
              "      <th>29</th>\n",
              "      <th>30</th>\n",
              "      <th>31</th>\n",
              "      <th>32</th>\n",
              "      <th>33</th>\n",
              "      <th>34</th>\n",
              "      <th>35</th>\n",
              "      <th>36</th>\n",
              "      <th>...</th>\n",
              "      <th>728</th>\n",
              "      <th>729</th>\n",
              "      <th>730</th>\n",
              "      <th>731</th>\n",
              "      <th>732</th>\n",
              "      <th>733</th>\n",
              "      <th>734</th>\n",
              "      <th>735</th>\n",
              "      <th>736</th>\n",
              "      <th>737</th>\n",
              "      <th>738</th>\n",
              "      <th>739</th>\n",
              "      <th>740</th>\n",
              "      <th>741</th>\n",
              "      <th>742</th>\n",
              "      <th>743</th>\n",
              "      <th>744</th>\n",
              "      <th>745</th>\n",
              "      <th>746</th>\n",
              "      <th>747</th>\n",
              "      <th>748</th>\n",
              "      <th>749</th>\n",
              "      <th>750</th>\n",
              "      <th>751</th>\n",
              "      <th>752</th>\n",
              "      <th>753</th>\n",
              "      <th>754</th>\n",
              "      <th>755</th>\n",
              "      <th>756</th>\n",
              "      <th>757</th>\n",
              "      <th>758</th>\n",
              "      <th>759</th>\n",
              "      <th>760</th>\n",
              "      <th>761</th>\n",
              "      <th>762</th>\n",
              "      <th>763</th>\n",
              "      <th>764</th>\n",
              "      <th>765</th>\n",
              "      <th>766</th>\n",
              "      <th>767</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>new</td>\n",
              "      <td>1</td>\n",
              "      <td>train</td>\n",
              "      <td>0.233748</td>\n",
              "      <td>-0.222612</td>\n",
              "      <td>0.270256</td>\n",
              "      <td>-0.100287</td>\n",
              "      <td>-0.300720</td>\n",
              "      <td>-0.325703</td>\n",
              "      <td>0.098270</td>\n",
              "      <td>0.133757</td>\n",
              "      <td>0.152318</td>\n",
              "      <td>-0.164728</td>\n",
              "      <td>0.059485</td>\n",
              "      <td>0.067821</td>\n",
              "      <td>0.154655</td>\n",
              "      <td>0.054548</td>\n",
              "      <td>0.070726</td>\n",
              "      <td>0.094673</td>\n",
              "      <td>0.025708</td>\n",
              "      <td>0.182993</td>\n",
              "      <td>-0.061887</td>\n",
              "      <td>-0.094307</td>\n",
              "      <td>0.200075</td>\n",
              "      <td>0.183657</td>\n",
              "      <td>0.371726</td>\n",
              "      <td>0.157151</td>\n",
              "      <td>-0.303662</td>\n",
              "      <td>-0.008491</td>\n",
              "      <td>0.088826</td>\n",
              "      <td>0.353364</td>\n",
              "      <td>0.249312</td>\n",
              "      <td>0.121814</td>\n",
              "      <td>-0.034527</td>\n",
              "      <td>0.101452</td>\n",
              "      <td>-0.247590</td>\n",
              "      <td>-0.023922</td>\n",
              "      <td>0.117177</td>\n",
              "      <td>0.410878</td>\n",
              "      <td>0.230009</td>\n",
              "      <td>...</td>\n",
              "      <td>0.071863</td>\n",
              "      <td>-0.152217</td>\n",
              "      <td>0.314776</td>\n",
              "      <td>0.201787</td>\n",
              "      <td>-0.221472</td>\n",
              "      <td>-0.323454</td>\n",
              "      <td>0.174078</td>\n",
              "      <td>0.200405</td>\n",
              "      <td>-0.265705</td>\n",
              "      <td>-0.298790</td>\n",
              "      <td>0.070027</td>\n",
              "      <td>0.126764</td>\n",
              "      <td>-0.550447</td>\n",
              "      <td>-0.102870</td>\n",
              "      <td>-0.299570</td>\n",
              "      <td>-0.029691</td>\n",
              "      <td>0.404164</td>\n",
              "      <td>-0.019329</td>\n",
              "      <td>0.571036</td>\n",
              "      <td>0.190634</td>\n",
              "      <td>0.486656</td>\n",
              "      <td>0.551100</td>\n",
              "      <td>-0.386007</td>\n",
              "      <td>0.553028</td>\n",
              "      <td>-1.911065</td>\n",
              "      <td>-0.058988</td>\n",
              "      <td>0.323301</td>\n",
              "      <td>-0.237405</td>\n",
              "      <td>-0.305501</td>\n",
              "      <td>0.051247</td>\n",
              "      <td>0.153581</td>\n",
              "      <td>-0.213971</td>\n",
              "      <td>0.118741</td>\n",
              "      <td>0.043862</td>\n",
              "      <td>0.268415</td>\n",
              "      <td>-0.302601</td>\n",
              "      <td>0.125339</td>\n",
              "      <td>-0.080951</td>\n",
              "      <td>0.739329</td>\n",
              "      <td>-0.114249</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>splash</td>\n",
              "      <td>1</td>\n",
              "      <td>train</td>\n",
              "      <td>0.228766</td>\n",
              "      <td>-0.213051</td>\n",
              "      <td>0.269835</td>\n",
              "      <td>-0.105115</td>\n",
              "      <td>-0.287857</td>\n",
              "      <td>-0.327904</td>\n",
              "      <td>0.094252</td>\n",
              "      <td>0.142503</td>\n",
              "      <td>0.132024</td>\n",
              "      <td>-0.151163</td>\n",
              "      <td>0.041531</td>\n",
              "      <td>0.040875</td>\n",
              "      <td>0.138779</td>\n",
              "      <td>0.078235</td>\n",
              "      <td>0.106438</td>\n",
              "      <td>0.086404</td>\n",
              "      <td>0.027244</td>\n",
              "      <td>0.184633</td>\n",
              "      <td>-0.068678</td>\n",
              "      <td>-0.089802</td>\n",
              "      <td>0.220715</td>\n",
              "      <td>0.194699</td>\n",
              "      <td>0.361071</td>\n",
              "      <td>0.150054</td>\n",
              "      <td>-0.308772</td>\n",
              "      <td>-0.010095</td>\n",
              "      <td>0.081059</td>\n",
              "      <td>0.379280</td>\n",
              "      <td>0.270694</td>\n",
              "      <td>0.124302</td>\n",
              "      <td>-0.009915</td>\n",
              "      <td>0.101524</td>\n",
              "      <td>-0.274123</td>\n",
              "      <td>-0.001704</td>\n",
              "      <td>0.120562</td>\n",
              "      <td>0.417391</td>\n",
              "      <td>0.221819</td>\n",
              "      <td>...</td>\n",
              "      <td>0.083690</td>\n",
              "      <td>-0.127704</td>\n",
              "      <td>0.329485</td>\n",
              "      <td>0.195718</td>\n",
              "      <td>-0.229896</td>\n",
              "      <td>-0.313575</td>\n",
              "      <td>0.182254</td>\n",
              "      <td>0.225073</td>\n",
              "      <td>-0.247905</td>\n",
              "      <td>-0.318933</td>\n",
              "      <td>0.059433</td>\n",
              "      <td>0.151987</td>\n",
              "      <td>-0.526197</td>\n",
              "      <td>-0.111069</td>\n",
              "      <td>-0.309602</td>\n",
              "      <td>-0.072512</td>\n",
              "      <td>0.383994</td>\n",
              "      <td>-0.030114</td>\n",
              "      <td>0.589379</td>\n",
              "      <td>0.152922</td>\n",
              "      <td>0.489214</td>\n",
              "      <td>0.548229</td>\n",
              "      <td>-0.347108</td>\n",
              "      <td>0.546480</td>\n",
              "      <td>-1.840378</td>\n",
              "      <td>-0.056731</td>\n",
              "      <td>0.301595</td>\n",
              "      <td>-0.239712</td>\n",
              "      <td>-0.268412</td>\n",
              "      <td>0.031907</td>\n",
              "      <td>0.147502</td>\n",
              "      <td>-0.197948</td>\n",
              "      <td>0.107488</td>\n",
              "      <td>0.052756</td>\n",
              "      <td>0.274319</td>\n",
              "      <td>-0.344441</td>\n",
              "      <td>0.114706</td>\n",
              "      <td>-0.092937</td>\n",
              "      <td>0.752795</td>\n",
              "      <td>-0.112522</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>a splash</td>\n",
              "      <td>1</td>\n",
              "      <td>train</td>\n",
              "      <td>0.137208</td>\n",
              "      <td>-0.313407</td>\n",
              "      <td>0.226146</td>\n",
              "      <td>-0.145098</td>\n",
              "      <td>-0.333741</td>\n",
              "      <td>-0.297151</td>\n",
              "      <td>0.106797</td>\n",
              "      <td>0.168328</td>\n",
              "      <td>0.101784</td>\n",
              "      <td>-0.184324</td>\n",
              "      <td>0.039666</td>\n",
              "      <td>0.053048</td>\n",
              "      <td>0.102834</td>\n",
              "      <td>0.053712</td>\n",
              "      <td>0.111207</td>\n",
              "      <td>0.077887</td>\n",
              "      <td>0.088367</td>\n",
              "      <td>0.120244</td>\n",
              "      <td>-0.025574</td>\n",
              "      <td>-0.164704</td>\n",
              "      <td>0.202982</td>\n",
              "      <td>0.149291</td>\n",
              "      <td>0.322254</td>\n",
              "      <td>0.122163</td>\n",
              "      <td>-0.323041</td>\n",
              "      <td>-0.001579</td>\n",
              "      <td>0.060177</td>\n",
              "      <td>0.414324</td>\n",
              "      <td>0.282133</td>\n",
              "      <td>0.062029</td>\n",
              "      <td>-0.011203</td>\n",
              "      <td>0.122134</td>\n",
              "      <td>-0.283472</td>\n",
              "      <td>-0.024014</td>\n",
              "      <td>0.099437</td>\n",
              "      <td>0.403131</td>\n",
              "      <td>0.209398</td>\n",
              "      <td>...</td>\n",
              "      <td>0.038631</td>\n",
              "      <td>-0.081034</td>\n",
              "      <td>0.296371</td>\n",
              "      <td>0.220180</td>\n",
              "      <td>-0.179793</td>\n",
              "      <td>-0.227891</td>\n",
              "      <td>0.124991</td>\n",
              "      <td>0.270777</td>\n",
              "      <td>-0.279706</td>\n",
              "      <td>-0.341659</td>\n",
              "      <td>0.121395</td>\n",
              "      <td>0.146233</td>\n",
              "      <td>-0.450034</td>\n",
              "      <td>-0.127239</td>\n",
              "      <td>-0.290212</td>\n",
              "      <td>-0.045135</td>\n",
              "      <td>0.410151</td>\n",
              "      <td>-0.031487</td>\n",
              "      <td>0.610341</td>\n",
              "      <td>0.229661</td>\n",
              "      <td>0.499684</td>\n",
              "      <td>0.526126</td>\n",
              "      <td>-0.319516</td>\n",
              "      <td>0.539215</td>\n",
              "      <td>-2.143089</td>\n",
              "      <td>-0.110817</td>\n",
              "      <td>0.372068</td>\n",
              "      <td>-0.199167</td>\n",
              "      <td>-0.221969</td>\n",
              "      <td>0.033992</td>\n",
              "      <td>0.137732</td>\n",
              "      <td>-0.277083</td>\n",
              "      <td>0.125730</td>\n",
              "      <td>0.019359</td>\n",
              "      <td>0.279625</td>\n",
              "      <td>-0.349366</td>\n",
              "      <td>0.141246</td>\n",
              "      <td>-0.061664</td>\n",
              "      <td>0.745740</td>\n",
              "      <td>-0.145962</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>greater</td>\n",
              "      <td>1</td>\n",
              "      <td>train</td>\n",
              "      <td>0.244372</td>\n",
              "      <td>-0.217179</td>\n",
              "      <td>0.269788</td>\n",
              "      <td>-0.094656</td>\n",
              "      <td>-0.301732</td>\n",
              "      <td>-0.294661</td>\n",
              "      <td>0.091070</td>\n",
              "      <td>0.110606</td>\n",
              "      <td>0.174087</td>\n",
              "      <td>-0.179891</td>\n",
              "      <td>0.052742</td>\n",
              "      <td>0.062356</td>\n",
              "      <td>0.154614</td>\n",
              "      <td>0.066787</td>\n",
              "      <td>0.078596</td>\n",
              "      <td>0.127853</td>\n",
              "      <td>0.024617</td>\n",
              "      <td>0.193381</td>\n",
              "      <td>-0.070514</td>\n",
              "      <td>-0.097330</td>\n",
              "      <td>0.234804</td>\n",
              "      <td>0.186447</td>\n",
              "      <td>0.371675</td>\n",
              "      <td>0.142241</td>\n",
              "      <td>-0.281845</td>\n",
              "      <td>0.007323</td>\n",
              "      <td>0.096840</td>\n",
              "      <td>0.356626</td>\n",
              "      <td>0.263522</td>\n",
              "      <td>0.109556</td>\n",
              "      <td>-0.027977</td>\n",
              "      <td>0.111564</td>\n",
              "      <td>-0.236160</td>\n",
              "      <td>-0.023637</td>\n",
              "      <td>0.137945</td>\n",
              "      <td>0.407442</td>\n",
              "      <td>0.244270</td>\n",
              "      <td>...</td>\n",
              "      <td>0.084874</td>\n",
              "      <td>-0.162486</td>\n",
              "      <td>0.300559</td>\n",
              "      <td>0.175735</td>\n",
              "      <td>-0.212014</td>\n",
              "      <td>-0.340711</td>\n",
              "      <td>0.197919</td>\n",
              "      <td>0.228845</td>\n",
              "      <td>-0.266319</td>\n",
              "      <td>-0.290397</td>\n",
              "      <td>0.042575</td>\n",
              "      <td>0.110049</td>\n",
              "      <td>-0.557211</td>\n",
              "      <td>-0.127587</td>\n",
              "      <td>-0.319923</td>\n",
              "      <td>-0.082219</td>\n",
              "      <td>0.399482</td>\n",
              "      <td>-0.043312</td>\n",
              "      <td>0.566519</td>\n",
              "      <td>0.173379</td>\n",
              "      <td>0.497346</td>\n",
              "      <td>0.537554</td>\n",
              "      <td>-0.367088</td>\n",
              "      <td>0.539756</td>\n",
              "      <td>-1.802676</td>\n",
              "      <td>-0.053023</td>\n",
              "      <td>0.306297</td>\n",
              "      <td>-0.248355</td>\n",
              "      <td>-0.305621</td>\n",
              "      <td>0.067563</td>\n",
              "      <td>0.146111</td>\n",
              "      <td>-0.167810</td>\n",
              "      <td>0.103711</td>\n",
              "      <td>0.085224</td>\n",
              "      <td>0.280750</td>\n",
              "      <td>-0.313768</td>\n",
              "      <td>0.111990</td>\n",
              "      <td>-0.080343</td>\n",
              "      <td>0.747839</td>\n",
              "      <td>-0.102310</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>a splash even greater</td>\n",
              "      <td>1</td>\n",
              "      <td>train</td>\n",
              "      <td>0.116057</td>\n",
              "      <td>-0.320084</td>\n",
              "      <td>0.254907</td>\n",
              "      <td>-0.138818</td>\n",
              "      <td>-0.379596</td>\n",
              "      <td>-0.254312</td>\n",
              "      <td>0.114805</td>\n",
              "      <td>0.132261</td>\n",
              "      <td>0.120034</td>\n",
              "      <td>-0.178364</td>\n",
              "      <td>0.047495</td>\n",
              "      <td>0.056912</td>\n",
              "      <td>0.107906</td>\n",
              "      <td>0.044411</td>\n",
              "      <td>0.092156</td>\n",
              "      <td>0.090518</td>\n",
              "      <td>0.086257</td>\n",
              "      <td>0.132753</td>\n",
              "      <td>-0.039860</td>\n",
              "      <td>-0.192983</td>\n",
              "      <td>0.247434</td>\n",
              "      <td>0.179486</td>\n",
              "      <td>0.325996</td>\n",
              "      <td>0.141971</td>\n",
              "      <td>-0.282370</td>\n",
              "      <td>0.046241</td>\n",
              "      <td>0.103692</td>\n",
              "      <td>0.376202</td>\n",
              "      <td>0.291703</td>\n",
              "      <td>0.058385</td>\n",
              "      <td>-0.024277</td>\n",
              "      <td>0.120918</td>\n",
              "      <td>-0.214091</td>\n",
              "      <td>-0.055446</td>\n",
              "      <td>0.115729</td>\n",
              "      <td>0.371713</td>\n",
              "      <td>0.215142</td>\n",
              "      <td>...</td>\n",
              "      <td>0.006576</td>\n",
              "      <td>-0.110749</td>\n",
              "      <td>0.246992</td>\n",
              "      <td>0.248963</td>\n",
              "      <td>-0.131180</td>\n",
              "      <td>-0.226813</td>\n",
              "      <td>0.143871</td>\n",
              "      <td>0.223728</td>\n",
              "      <td>-0.269268</td>\n",
              "      <td>-0.295310</td>\n",
              "      <td>0.118405</td>\n",
              "      <td>0.100665</td>\n",
              "      <td>-0.425310</td>\n",
              "      <td>-0.131608</td>\n",
              "      <td>-0.270887</td>\n",
              "      <td>-0.027740</td>\n",
              "      <td>0.406834</td>\n",
              "      <td>-0.056763</td>\n",
              "      <td>0.584286</td>\n",
              "      <td>0.262561</td>\n",
              "      <td>0.473144</td>\n",
              "      <td>0.512692</td>\n",
              "      <td>-0.325140</td>\n",
              "      <td>0.511781</td>\n",
              "      <td>-2.274318</td>\n",
              "      <td>-0.107610</td>\n",
              "      <td>0.371740</td>\n",
              "      <td>-0.189074</td>\n",
              "      <td>-0.268028</td>\n",
              "      <td>0.013178</td>\n",
              "      <td>0.124603</td>\n",
              "      <td>-0.250946</td>\n",
              "      <td>0.128571</td>\n",
              "      <td>0.057087</td>\n",
              "      <td>0.300056</td>\n",
              "      <td>-0.298960</td>\n",
              "      <td>0.138990</td>\n",
              "      <td>-0.040336</td>\n",
              "      <td>0.707032</td>\n",
              "      <td>-0.118647</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>...</th>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "      <td>...</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>495</th>\n",
              "      <td>made me unintentionally famous</td>\n",
              "      <td>1</td>\n",
              "      <td>train</td>\n",
              "      <td>0.201728</td>\n",
              "      <td>-0.260246</td>\n",
              "      <td>0.331977</td>\n",
              "      <td>-0.118236</td>\n",
              "      <td>-0.398582</td>\n",
              "      <td>-0.277229</td>\n",
              "      <td>0.182984</td>\n",
              "      <td>0.154444</td>\n",
              "      <td>0.164846</td>\n",
              "      <td>-0.049436</td>\n",
              "      <td>0.089517</td>\n",
              "      <td>0.044725</td>\n",
              "      <td>0.124052</td>\n",
              "      <td>-0.062668</td>\n",
              "      <td>-0.049897</td>\n",
              "      <td>0.030618</td>\n",
              "      <td>0.121208</td>\n",
              "      <td>0.165681</td>\n",
              "      <td>-0.051339</td>\n",
              "      <td>-0.153854</td>\n",
              "      <td>0.216104</td>\n",
              "      <td>0.275640</td>\n",
              "      <td>0.274642</td>\n",
              "      <td>0.161544</td>\n",
              "      <td>-0.215160</td>\n",
              "      <td>0.098490</td>\n",
              "      <td>0.141741</td>\n",
              "      <td>0.306120</td>\n",
              "      <td>0.200891</td>\n",
              "      <td>0.096325</td>\n",
              "      <td>-0.021883</td>\n",
              "      <td>0.044556</td>\n",
              "      <td>-0.041244</td>\n",
              "      <td>-0.083654</td>\n",
              "      <td>0.002887</td>\n",
              "      <td>0.237949</td>\n",
              "      <td>0.214569</td>\n",
              "      <td>...</td>\n",
              "      <td>0.075177</td>\n",
              "      <td>-0.229180</td>\n",
              "      <td>0.203458</td>\n",
              "      <td>0.377879</td>\n",
              "      <td>-0.125026</td>\n",
              "      <td>-0.281630</td>\n",
              "      <td>0.036875</td>\n",
              "      <td>-0.001598</td>\n",
              "      <td>-0.178603</td>\n",
              "      <td>-0.218029</td>\n",
              "      <td>0.079146</td>\n",
              "      <td>0.174302</td>\n",
              "      <td>-0.402834</td>\n",
              "      <td>-0.117914</td>\n",
              "      <td>-0.349784</td>\n",
              "      <td>0.024555</td>\n",
              "      <td>0.332719</td>\n",
              "      <td>-0.019121</td>\n",
              "      <td>0.625054</td>\n",
              "      <td>0.299865</td>\n",
              "      <td>0.467325</td>\n",
              "      <td>0.474676</td>\n",
              "      <td>-0.439971</td>\n",
              "      <td>0.443952</td>\n",
              "      <td>-1.669511</td>\n",
              "      <td>-0.058426</td>\n",
              "      <td>0.311317</td>\n",
              "      <td>-0.106407</td>\n",
              "      <td>-0.399833</td>\n",
              "      <td>-0.115167</td>\n",
              "      <td>0.190142</td>\n",
              "      <td>-0.268151</td>\n",
              "      <td>0.212878</td>\n",
              "      <td>-0.015738</td>\n",
              "      <td>0.268094</td>\n",
              "      <td>-0.191102</td>\n",
              "      <td>0.140078</td>\n",
              "      <td>0.015779</td>\n",
              "      <td>0.554127</td>\n",
              "      <td>-0.066951</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>496</th>\n",
              "      <td>queasy stomached</td>\n",
              "      <td>0</td>\n",
              "      <td>train</td>\n",
              "      <td>0.182803</td>\n",
              "      <td>-0.242999</td>\n",
              "      <td>0.245358</td>\n",
              "      <td>-0.096413</td>\n",
              "      <td>-0.355837</td>\n",
              "      <td>-0.254806</td>\n",
              "      <td>0.102751</td>\n",
              "      <td>0.145076</td>\n",
              "      <td>0.127309</td>\n",
              "      <td>-0.103663</td>\n",
              "      <td>0.036153</td>\n",
              "      <td>0.027669</td>\n",
              "      <td>0.130667</td>\n",
              "      <td>0.019470</td>\n",
              "      <td>0.007506</td>\n",
              "      <td>0.050136</td>\n",
              "      <td>0.055982</td>\n",
              "      <td>0.155171</td>\n",
              "      <td>-0.002286</td>\n",
              "      <td>-0.139393</td>\n",
              "      <td>0.225504</td>\n",
              "      <td>0.198623</td>\n",
              "      <td>0.301767</td>\n",
              "      <td>0.128633</td>\n",
              "      <td>-0.250357</td>\n",
              "      <td>0.022528</td>\n",
              "      <td>0.086855</td>\n",
              "      <td>0.392883</td>\n",
              "      <td>0.246813</td>\n",
              "      <td>0.096819</td>\n",
              "      <td>-0.023480</td>\n",
              "      <td>0.110337</td>\n",
              "      <td>-0.185586</td>\n",
              "      <td>-0.099878</td>\n",
              "      <td>0.069517</td>\n",
              "      <td>0.336524</td>\n",
              "      <td>0.202507</td>\n",
              "      <td>...</td>\n",
              "      <td>0.087723</td>\n",
              "      <td>-0.135970</td>\n",
              "      <td>0.264125</td>\n",
              "      <td>0.278413</td>\n",
              "      <td>-0.143147</td>\n",
              "      <td>-0.321834</td>\n",
              "      <td>0.141283</td>\n",
              "      <td>0.187914</td>\n",
              "      <td>-0.211738</td>\n",
              "      <td>-0.256016</td>\n",
              "      <td>0.076508</td>\n",
              "      <td>0.133364</td>\n",
              "      <td>-0.480117</td>\n",
              "      <td>-0.103416</td>\n",
              "      <td>-0.305630</td>\n",
              "      <td>0.012182</td>\n",
              "      <td>0.353097</td>\n",
              "      <td>-0.060268</td>\n",
              "      <td>0.550501</td>\n",
              "      <td>0.251482</td>\n",
              "      <td>0.442724</td>\n",
              "      <td>0.487369</td>\n",
              "      <td>-0.408622</td>\n",
              "      <td>0.510985</td>\n",
              "      <td>-1.836359</td>\n",
              "      <td>-0.076481</td>\n",
              "      <td>0.289726</td>\n",
              "      <td>-0.199641</td>\n",
              "      <td>-0.320311</td>\n",
              "      <td>0.001357</td>\n",
              "      <td>0.179691</td>\n",
              "      <td>-0.234576</td>\n",
              "      <td>0.127190</td>\n",
              "      <td>0.048551</td>\n",
              "      <td>0.263170</td>\n",
              "      <td>-0.247456</td>\n",
              "      <td>0.093353</td>\n",
              "      <td>-0.018356</td>\n",
              "      <td>0.687621</td>\n",
              "      <td>-0.089248</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>497</th>\n",
              "      <td>theater</td>\n",
              "      <td>1</td>\n",
              "      <td>train</td>\n",
              "      <td>0.209967</td>\n",
              "      <td>-0.213674</td>\n",
              "      <td>0.248412</td>\n",
              "      <td>-0.119370</td>\n",
              "      <td>-0.267879</td>\n",
              "      <td>-0.315400</td>\n",
              "      <td>0.080620</td>\n",
              "      <td>0.130779</td>\n",
              "      <td>0.138751</td>\n",
              "      <td>-0.150797</td>\n",
              "      <td>0.055821</td>\n",
              "      <td>0.063313</td>\n",
              "      <td>0.128407</td>\n",
              "      <td>0.083175</td>\n",
              "      <td>0.055759</td>\n",
              "      <td>0.093520</td>\n",
              "      <td>0.024823</td>\n",
              "      <td>0.183015</td>\n",
              "      <td>-0.050592</td>\n",
              "      <td>-0.086988</td>\n",
              "      <td>0.169125</td>\n",
              "      <td>0.156934</td>\n",
              "      <td>0.375840</td>\n",
              "      <td>0.164646</td>\n",
              "      <td>-0.313982</td>\n",
              "      <td>-0.031767</td>\n",
              "      <td>0.066157</td>\n",
              "      <td>0.372339</td>\n",
              "      <td>0.263345</td>\n",
              "      <td>0.109531</td>\n",
              "      <td>-0.035460</td>\n",
              "      <td>0.118850</td>\n",
              "      <td>-0.267497</td>\n",
              "      <td>-0.013979</td>\n",
              "      <td>0.135073</td>\n",
              "      <td>0.386075</td>\n",
              "      <td>0.245364</td>\n",
              "      <td>...</td>\n",
              "      <td>0.092160</td>\n",
              "      <td>-0.149491</td>\n",
              "      <td>0.308301</td>\n",
              "      <td>0.174255</td>\n",
              "      <td>-0.249452</td>\n",
              "      <td>-0.332275</td>\n",
              "      <td>0.191591</td>\n",
              "      <td>0.211993</td>\n",
              "      <td>-0.243143</td>\n",
              "      <td>-0.286586</td>\n",
              "      <td>0.028451</td>\n",
              "      <td>0.152585</td>\n",
              "      <td>-0.544175</td>\n",
              "      <td>-0.102748</td>\n",
              "      <td>-0.260646</td>\n",
              "      <td>-0.012314</td>\n",
              "      <td>0.431076</td>\n",
              "      <td>0.001276</td>\n",
              "      <td>0.595296</td>\n",
              "      <td>0.145947</td>\n",
              "      <td>0.466487</td>\n",
              "      <td>0.549567</td>\n",
              "      <td>-0.350339</td>\n",
              "      <td>0.545786</td>\n",
              "      <td>-2.003801</td>\n",
              "      <td>-0.062118</td>\n",
              "      <td>0.293004</td>\n",
              "      <td>-0.246441</td>\n",
              "      <td>-0.270356</td>\n",
              "      <td>0.064427</td>\n",
              "      <td>0.167575</td>\n",
              "      <td>-0.219617</td>\n",
              "      <td>0.117826</td>\n",
              "      <td>0.057679</td>\n",
              "      <td>0.264792</td>\n",
              "      <td>-0.348748</td>\n",
              "      <td>0.104451</td>\n",
              "      <td>-0.076590</td>\n",
              "      <td>0.768219</td>\n",
              "      <td>-0.136208</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>498</th>\n",
              "      <td>staggered from the theater</td>\n",
              "      <td>0</td>\n",
              "      <td>train</td>\n",
              "      <td>0.101467</td>\n",
              "      <td>-0.311635</td>\n",
              "      <td>0.240607</td>\n",
              "      <td>-0.156736</td>\n",
              "      <td>-0.323171</td>\n",
              "      <td>-0.263926</td>\n",
              "      <td>0.131896</td>\n",
              "      <td>0.169260</td>\n",
              "      <td>0.110856</td>\n",
              "      <td>-0.112883</td>\n",
              "      <td>0.062913</td>\n",
              "      <td>0.069407</td>\n",
              "      <td>0.089386</td>\n",
              "      <td>0.045057</td>\n",
              "      <td>0.014777</td>\n",
              "      <td>0.042182</td>\n",
              "      <td>0.091393</td>\n",
              "      <td>0.090369</td>\n",
              "      <td>-0.015866</td>\n",
              "      <td>-0.207520</td>\n",
              "      <td>0.152777</td>\n",
              "      <td>0.129481</td>\n",
              "      <td>0.321496</td>\n",
              "      <td>0.130120</td>\n",
              "      <td>-0.275001</td>\n",
              "      <td>0.000389</td>\n",
              "      <td>0.065203</td>\n",
              "      <td>0.380027</td>\n",
              "      <td>0.255409</td>\n",
              "      <td>0.019734</td>\n",
              "      <td>-0.005136</td>\n",
              "      <td>0.127422</td>\n",
              "      <td>-0.174967</td>\n",
              "      <td>-0.060010</td>\n",
              "      <td>0.090565</td>\n",
              "      <td>0.326273</td>\n",
              "      <td>0.266211</td>\n",
              "      <td>...</td>\n",
              "      <td>0.038294</td>\n",
              "      <td>-0.108615</td>\n",
              "      <td>0.234072</td>\n",
              "      <td>0.278265</td>\n",
              "      <td>-0.166622</td>\n",
              "      <td>-0.254264</td>\n",
              "      <td>0.073088</td>\n",
              "      <td>0.161864</td>\n",
              "      <td>-0.250137</td>\n",
              "      <td>-0.289731</td>\n",
              "      <td>0.129631</td>\n",
              "      <td>0.159090</td>\n",
              "      <td>-0.362076</td>\n",
              "      <td>-0.124927</td>\n",
              "      <td>-0.234639</td>\n",
              "      <td>0.070827</td>\n",
              "      <td>0.420792</td>\n",
              "      <td>-0.022967</td>\n",
              "      <td>0.604958</td>\n",
              "      <td>0.233506</td>\n",
              "      <td>0.368468</td>\n",
              "      <td>0.498633</td>\n",
              "      <td>-0.359628</td>\n",
              "      <td>0.543543</td>\n",
              "      <td>-2.539260</td>\n",
              "      <td>-0.159995</td>\n",
              "      <td>0.314086</td>\n",
              "      <td>-0.158531</td>\n",
              "      <td>-0.276643</td>\n",
              "      <td>0.011032</td>\n",
              "      <td>0.144199</td>\n",
              "      <td>-0.302733</td>\n",
              "      <td>0.140596</td>\n",
              "      <td>0.013261</td>\n",
              "      <td>0.282099</td>\n",
              "      <td>-0.319144</td>\n",
              "      <td>0.094871</td>\n",
              "      <td>-0.013283</td>\n",
              "      <td>0.696700</td>\n",
              "      <td>-0.137790</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>499</th>\n",
              "      <td>blacked out</td>\n",
              "      <td>0</td>\n",
              "      <td>train</td>\n",
              "      <td>0.251036</td>\n",
              "      <td>-0.233704</td>\n",
              "      <td>0.312284</td>\n",
              "      <td>-0.095550</td>\n",
              "      <td>-0.338005</td>\n",
              "      <td>-0.289469</td>\n",
              "      <td>0.131197</td>\n",
              "      <td>0.133932</td>\n",
              "      <td>0.134819</td>\n",
              "      <td>-0.105469</td>\n",
              "      <td>0.067544</td>\n",
              "      <td>0.038067</td>\n",
              "      <td>0.135163</td>\n",
              "      <td>0.046587</td>\n",
              "      <td>0.055289</td>\n",
              "      <td>0.079387</td>\n",
              "      <td>0.064815</td>\n",
              "      <td>0.169503</td>\n",
              "      <td>-0.056528</td>\n",
              "      <td>-0.132202</td>\n",
              "      <td>0.238173</td>\n",
              "      <td>0.224305</td>\n",
              "      <td>0.313463</td>\n",
              "      <td>0.183716</td>\n",
              "      <td>-0.261935</td>\n",
              "      <td>-0.017931</td>\n",
              "      <td>0.087245</td>\n",
              "      <td>0.377508</td>\n",
              "      <td>0.230760</td>\n",
              "      <td>0.096778</td>\n",
              "      <td>-0.042611</td>\n",
              "      <td>0.098273</td>\n",
              "      <td>-0.174714</td>\n",
              "      <td>-0.073361</td>\n",
              "      <td>0.109003</td>\n",
              "      <td>0.355861</td>\n",
              "      <td>0.205164</td>\n",
              "      <td>...</td>\n",
              "      <td>0.066876</td>\n",
              "      <td>-0.145985</td>\n",
              "      <td>0.269457</td>\n",
              "      <td>0.271397</td>\n",
              "      <td>-0.154132</td>\n",
              "      <td>-0.309572</td>\n",
              "      <td>0.130990</td>\n",
              "      <td>0.153329</td>\n",
              "      <td>-0.254692</td>\n",
              "      <td>-0.279180</td>\n",
              "      <td>0.090058</td>\n",
              "      <td>0.115682</td>\n",
              "      <td>-0.493049</td>\n",
              "      <td>-0.085393</td>\n",
              "      <td>-0.365650</td>\n",
              "      <td>-0.029009</td>\n",
              "      <td>0.337045</td>\n",
              "      <td>-0.036840</td>\n",
              "      <td>0.562763</td>\n",
              "      <td>0.231064</td>\n",
              "      <td>0.461940</td>\n",
              "      <td>0.506443</td>\n",
              "      <td>-0.407299</td>\n",
              "      <td>0.548461</td>\n",
              "      <td>-1.699150</td>\n",
              "      <td>-0.035893</td>\n",
              "      <td>0.301651</td>\n",
              "      <td>-0.175210</td>\n",
              "      <td>-0.311732</td>\n",
              "      <td>-0.020365</td>\n",
              "      <td>0.172715</td>\n",
              "      <td>-0.232078</td>\n",
              "      <td>0.130300</td>\n",
              "      <td>0.042531</td>\n",
              "      <td>0.268110</td>\n",
              "      <td>-0.269198</td>\n",
              "      <td>0.091461</td>\n",
              "      <td>-0.050782</td>\n",
              "      <td>0.674053</td>\n",
              "      <td>-0.103417</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "<p>500 rows × 771 columns</p>\n",
              "</div>"
            ],
            "text/plain": [
              "                           sentence  label  split  ...       765       766       767\n",
              "0                               new      1  train  ... -0.080951  0.739329 -0.114249\n",
              "1                            splash      1  train  ... -0.092937  0.752795 -0.112522\n",
              "2                          a splash      1  train  ... -0.061664  0.745740 -0.145962\n",
              "3                           greater      1  train  ... -0.080343  0.747839 -0.102310\n",
              "4             a splash even greater      1  train  ... -0.040336  0.707032 -0.118647\n",
              "..                              ...    ...    ...  ...       ...       ...       ...\n",
              "495  made me unintentionally famous      1  train  ...  0.015779  0.554127 -0.066951\n",
              "496                queasy stomached      0  train  ... -0.018356  0.687621 -0.089248\n",
              "497                         theater      1  train  ... -0.076590  0.768219 -0.136208\n",
              "498      staggered from the theater      0  train  ... -0.013283  0.696700 -0.137790\n",
              "499                     blacked out      0  train  ... -0.050782  0.674053 -0.103417\n",
              "\n",
              "[500 rows x 771 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "hiPd3yok4A3i",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}